from torchvision import datasets, transforms


def load_dataset(dataset_name, size=(32, 32), *, root='data', download=True):
    """Build the train/test splits of a supported torchvision image dataset.

    Every split is resized to ``size``, converted to a tensor and normalized
    with per-dataset (mean, std) statistics.

    Args:
        dataset_name: One of ``'mnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'svhn'`` (case-sensitive).
        size: Target size passed to ``transforms.Resize``.
        root: Directory where the dataset files are stored / downloaded.
        download: Forwarded to the torchvision dataset constructors.

    Returns:
        A tuple ``(train_data, test_data, num_classes, num_input_channels)``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # name -> (normalize mean, normalize std, num_classes, num_input_channels)
    specs = {
        'mnist': ((0.1307,), (0.3081,), 10, 1),
        'cifar10': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 10, 3),
        'cifar100': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 100, 3),
        'svhn': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 10, 3),
    }
    # Validate before touching torchvision; an explicit raise (unlike
    # `assert`) is not stripped under `python -O`.
    if dataset_name not in specs:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of "
            f"{', '.join(sorted(specs))}")

    mean, std, num_classes, num_input_channels = specs[dataset_name]
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if dataset_name == 'svhn':
        # SVHN selects splits via `split=` rather than `train=`. torchvision's
        # SVHN already returns plain int labels with digit 0 mapped to 0 (not
        # 10), so no target_transform is needed; indexing the int label (as a
        # previous version did) raises TypeError.
        train_data = datasets.SVHN(root=root, split='train', download=download,
                                   transform=transform)
        test_data = datasets.SVHN(root=root, split='test', download=download,
                                  transform=transform)
    else:
        dataset_cls = {
            'mnist': datasets.MNIST,
            'cifar10': datasets.CIFAR10,
            'cifar100': datasets.CIFAR100,
        }[dataset_name]
        train_data = dataset_cls(root=root, train=True, download=download,
                                 transform=transform)
        test_data = dataset_cls(root=root, train=False, download=download,
                                transform=transform)

    return train_data, test_data, num_classes, num_input_channels
